{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using TensorFlow version: 2.0.0-beta1\n",
      "GPU Available: False\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "from src.lstm import SimpleLSTM\n",
    "from src.metrics import exact_match_metric\n",
    "from src.callbacks import NValidationSetsCallback\n",
    "from src.generators import DataGeneratorSeq\n",
    "from src.utils import get_sequence_data\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "\n",
    "print(f\"Using TensorFlow version: {tf.__version__}\")\n",
    "print(f\"GPU Available: {tf.test.is_gpu_available()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project directories, given relative to the notebook's working directory.\n",
    "SETTINGS = Path(\"..\", \"settings\")\n",
    "DATA = Path(\"..\", \"data\", \"processed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Location of the settings file that is loaded in the next cell.\n",
    "settings_path = SETTINGS / \"settings_local.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the settings as UTF-8 explicitly so that parsing does not depend on\n",
    "# the platform's default text encoding.\n",
    "with open(settings_path, \"r\", encoding=\"utf-8\") as settings_file:\n",
    "    settings_dict = json.load(settings_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'math_module': 'arithmetic__add_sub',\n",
       " 'train_level': '*',\n",
       " 'batch_size': 1024,\n",
       " 'thinking_steps': 16,\n",
       " 'epochs': 1,\n",
       " 'num_encoder_units': 512,\n",
       " 'num_decoder_units': 2048,\n",
       " 'embedding_dim': 2048,\n",
       " 'save_path': '../data/',\n",
       " 'data_path': '../data/'}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the loaded settings so the run configuration is visible in the notebook.\n",
    "settings_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_sequence_data returns the keyword arguments later unpacked into each\n",
    "# DataGeneratorSeq, plus the input and target texts that later cells index\n",
    "# by split name (e.g. \"train\").\n",
    "data_gen_pars, input_texts, target_texts = get_sequence_data(settings_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of training samples: 1599998\n"
     ]
    }
   ],
   "source": [
    "# Size of the training split.\n",
    "print(f\"Number of training samples: {len(input_texts['train'])}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of validation samples: 10000\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the label says \"validation\" but the count is taken from the\n",
    "# 'interpolate' split; the 'valid' split is what validation_generator uses.\n",
    "# Confirm which split this line is meant to report.\n",
    "print('Number of validation samples:', len(input_texts['interpolate']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['batch_size', 'max_encoder_seq_length', 'max_decoder_seq_length', 'max_seq_length', 'num_encoder_tokens', 'num_decoder_tokens', 'num_tokens', 'input_token_index', 'target_token_index', 'token_index', 'num_thinking_steps'])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Names of the keyword arguments that are unpacked into every DataGeneratorSeq.\n",
    "data_gen_pars.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data generators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_generator(split):\n",
    "    \"\"\"Build a DataGeneratorSeq over one named split of the loaded texts.\n",
    "\n",
    "    Every split shares the same keyword arguments from ``data_gen_pars``.\n",
    "    \"\"\"\n",
    "    return DataGeneratorSeq(\n",
    "        input_texts=input_texts[split],\n",
    "        target_texts=target_texts[split],\n",
    "        **data_gen_pars\n",
    "    )\n",
    "\n",
    "\n",
    "training_generator = _make_generator(\"train\")\n",
    "validation_generator = _make_generator(\"valid\")\n",
    "interpolate_generator = _make_generator(\"interpolate\")\n",
    "extrapolate_generator = _make_generator(\"extrapolate\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         [(None, None, 34)]        0         \n",
      "_________________________________________________________________\n",
      "lstm (LSTM)                  [(None, None, 2048), (Non 17063936  \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, None, 34)          69666     \n",
      "=================================================================\n",
      "Total params: 17,133,602\n",
      "Trainable params: 17,133,602\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Build the model from the token count and the configured embedding_dim,\n",
    "# then print its layer summary.\n",
    "lstm = SimpleLSTM(\n",
    "    data_gen_pars[\"num_tokens\"],\n",
    "    settings_dict[\"embedding_dim\"],\n",
    ")\n",
    "model = lstm.get_model()\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `learning_rate` is the supported keyword in TF 2.x; `lr` is only a deprecated alias.\n",
    "# `decay=0.0` was the default and is omitted because newer Keras optimizers\n",
    "# reject the argument.\n",
    "adam = Adam(\n",
    "    learning_rate=6e-4,\n",
    "    beta_1=0.9,\n",
    "    beta_2=0.995,\n",
    "    epsilon=1e-9,\n",
    "    amsgrad=False,\n",
    "    clipnorm=0.1,\n",
    ")\n",
    "# Train with categorical cross-entropy and report the project's exact-match metric.\n",
    "model.compile(\n",
    "    optimizer=adam, loss=\"categorical_crossentropy\", metrics=[exact_match_metric]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure callbacks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Named evaluation generators handed to NValidationSetsCallback in the next cell.\n",
    "valid_dict = dict(\n",
    "    validation=validation_generator,\n",
    "    interpolation=interpolate_generator,\n",
    "    extrapolation=extrapolate_generator,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callback built from the named evaluation generators; it is passed to the\n",
    "# training call together with the checkpoint callback.\n",
    "history = NValidationSetsCallback(valid_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# directory where the checkpoints will be saved; os.path.join keeps this\n",
    "# correct whether or not save_path ends with a path separator\n",
    "checkpoint_dir = os.path.join(settings_dict[\"save_path\"], \"training_checkpoints\")\n",
    "# name of the checkpoint files (the {epoch} placeholder is filled in by ModelCheckpoint)\n",
    "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
    "\n",
    "# save_weights_only=True stores only the model weights, not the full model\n",
    "checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
    "    filepath=checkpoint_prefix, save_weights_only=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model.fit accepts the data generator directly; Model.fit_generator is\n",
    "# deprecated in TF 2.x. Evaluation on the extra sets and checkpointing are\n",
    "# both driven by the callbacks.\n",
    "train_hist = model.fit(\n",
    "    training_generator,\n",
    "    epochs=settings_dict[\"epochs\"],\n",
    "    callbacks=[history, checkpoint_callback],\n",
    "    verbose=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
